In [17]:
# --- Imports: standard library first, then third-party (each imported once) ---
import pickle

# Numerics / inference
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import tensorflow_probability.substrates.jax as tfp
from scipy.stats import gaussian_kde

# Plotting
import matplotlib.pyplot as plt
import plotly
import plotly.express as px

# Conventional short alias for the TFP (JAX substrate) distributions namespace.
tfd = tfp.distributions

# Configure Plotly's offline notebook mode so figures render inline.
plotly.offline.init_notebook_mode()
In [18]:
# Evaluation grid for the posterior over `rate` (shown as "theta" in the plot).
# NOTE(review): bounds are hard-coded; confirm they cover the support of both estimates.
THETA_MIN, THETA_MAX = 8.0, 12.0
N_GRID = 5000
x = jnp.linspace(THETA_MIN, THETA_MAX, N_GRID)

# Densities evaluated on `x`, appended by the estimator cells below.
# Re-running this cell resets the list.
all_pdf = []
In [19]:
# Restore the saved Ajax variational posterior object.
# NOTE(review): pickle.load can execute arbitrary code; only open result files you produced yourself.
ajax_path = './results_data/weibull_Ajax'
with open(ajax_path, 'rb') as handle:
    variational_ajax = pickle.load(handle)
In [20]:
# Evaluate the Ajax variational density on the grid.
# log_prob is queried with a dict keyed by the latent name "rate"; the result is
# exponentiated to obtain density values.
di = dict(rate=x)
log_density_ajax = variational_ajax.log_prob(di)
posterior_ajax = jnp.exp(log_density_ajax)
all_pdf.append(posterior_ajax)
In [21]:
import numpy as np

# Number of initial draws discarded before building the KDE.
# NOTE(review): hard-coded; confirm it matches the chain's warm-up length.
BURN_IN = 300

# NOTE(review): pickle.load can execute arbitrary code; only open result files you produced yourself.
with open('./results_data/MCMC_Blackjax', 'rb') as f:
    mcmc_model = pickle.load(f)

# Keep the post-burn-in positions as a NumPy array for scipy's KDE.
# Presumably `.position` stacks draws along axis 0 (TODO confirm against the sampling notebook).
mcmc_samples = np.array(mcmc_model.position[BURN_IN:])
In [22]:
# Smooth the retained Blackjax RMH draws with a Gaussian KDE and evaluate it on the grid.
flat_samples_black = mcmc_samples.reshape(-1)
kde_black = gaussian_kde(flat_samples_black)
pdf_black = kde_black(x)
all_pdf.append(pdf_black)
pdf_black.shape
Out[22]:
(5000,)
In [23]:
# Long-format frame: one row per (grid point, estimator).
# Built from the named density arrays rather than the cross-cell `all_pdf` list.
# Re-running an estimator cell (which appends to that list) therefore cannot
# misalign labels and values, and `all_pdf` is no longer overwritten with an array.
estimates = {
    "Ajax VI estimate": posterior_ajax,
    "Blackjax RMH estimate": pdf_black,
}
n_grid = x.shape[0]
all_label = [label for label in estimates for _ in range(n_grid)]
x_repeated = jnp.tile(x, len(estimates))
pdf_long = jnp.concatenate([jnp.ravel(jnp.asarray(pdf)) for pdf in estimates.values()])
assert pdf_long.shape[0] == len(all_label) == x_repeated.shape[0], "label/value length mismatch"

to_df = {
    "theta": np.asarray(x_repeated),
    "PDF": np.asarray(pdf_long),
    "label": all_label,
}
df = pd.DataFrame(to_df)

# Plot from the DataFrame that was just built (previously it was created but unused).
fig = px.line(df, x="theta", y="PDF", color="label", title="Weibull Poisson results")
fig.show()
fig.write_html("weibull_poisson_result_plotly.html")
In [22]:
# Re-run Plotly's offline notebook setup.
# NOTE(review): this repeats the setup in the first cell and looks like a leftover; consider deleting it.
import plotly.offline
plotly.offline.init_notebook_mode()